import os
import torch
from torch import nn
import numpy as np

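# Alternative, deeper MLP variant (one more hidden Linear + GELU than the active
# `MLP` class below), kept commented out for reference. Whichever definition is
# active must match the architecture of the checkpoint being loaded.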
# class MLP(nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_dim=1024):
#         super(MLP, self).__init__()
#         self.input_dim = input_dim
#         hidden_dim = hidden_dim * 2
#         self.layers = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.GELU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.GELU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.GELU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.GELU(),
#             nn.Linear(hidden_dim, output_dim),
#         )
    
#     def forward(self, x):
#         return self.layers(x/np.sqrt(self.input_dim))
        

class MLP(nn.Module):
    """Feed-forward network mapping ``input_dim`` features to ``output_dim``.

    Layout (must match the checkpoint's state_dict keys ``layers.{0,2,4,6}.*``):
    Linear -> GELU -> Linear -> GELU -> Linear -> GELU -> Linear, where every
    hidden layer is ``2 * hidden_dim`` wide. Inputs are divided by
    ``sqrt(input_dim)`` before entering the first layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.input_dim = input_dim

        # The hidden width is double the requested hidden_dim.
        width = 2 * hidden_dim
        dims = [input_dim, width, width, width, output_dim]

        # First projection, then (GELU, Linear) pairs; this yields exactly the
        # module order Linear, GELU, Linear, GELU, Linear, GELU, Linear.
        modules = [nn.Linear(dims[0], dims[1])]
        for d_in, d_out in zip(dims[1:-1], dims[2:]):
            modules.append(nn.GELU())
            modules.append(nn.Linear(d_in, d_out))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Scale ``x`` by ``1 / sqrt(input_dim)`` and apply the layer stack."""
        scale = np.sqrt(self.input_dim)
        scaled = x / scale
        return self.layers(scaled)

class GemmaEmbeddingPredictor:
    """Wrap a trained ``MLP`` that predicts the next token's embedding.

    The MLP maps an ``input_dim``-dimensional embedding of the last token to an
    ``input_dim``-dimensional predicted embedding of the next token. Only the
    MLP is loaded here; computing the input embeddings is left to the caller.
    """

    def __init__(self, input_dim, mlp_model_path="gemma_mlp_model.pt"):
        """
        Initialize the predictor by loading the trained MLP weights.

        Args:
            input_dim: Dimensionality of the token embeddings; used as both the
                MLP's input size and its output size.
            mlp_model_path: Path to the saved MLP ``state_dict``.

        Raises:
            FileNotFoundError: If ``mlp_model_path`` does not exist.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Input and prediction live in the same embedding space.
        self.input_dim = input_dim
        self.output_dim = input_dim

        # Load the MLP model
        print(f"Loading MLP model from {mlp_model_path}")
        self.mlp_model = self._load_mlp_model(mlp_model_path)
        self.mlp_model.eval()  # Set to evaluation mode

    def _load_mlp_model(self, model_path):
        """
        Build an ``MLP`` on ``self.device`` in float32 and load saved weights.

        Args:
            model_path: Path to the saved ``state_dict``.

        Returns:
            The MLP with the checkpoint's weights loaded.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Architecture must match the one the checkpoint was trained with.
        model = MLP(self.input_dim, self.output_dim).to(self.device).float()

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from trusted sources. Since a
        # plain state_dict is expected here, passing ``weights_only=True``
        # (available in torch >= 1.13) would be safer where supported.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model

    def predict_next_token_embedding(self, last_token_embedding):
        """
        Predict the expected next-token embedding with the trained MLP.

        Args:
            last_token_embedding: Embedding(s) of the last token, as a tensor or
                array-like whose last dimension is ``input_dim``. Any numeric
                dtype is accepted; it is cast to float32 to match the MLP.

        Returns:
            torch.Tensor: Predicted embedding(s) in float32 on ``self.device``,
            computed without gradient tracking.
        """
        # The MLP is built with .float() in _load_mlp_model, so inputs in any
        # other dtype (e.g. half-precision hidden states or float64 arrays)
        # would otherwise fail with a dtype mismatch inside nn.Linear.
        inputs = torch.as_tensor(last_token_embedding).to(
            device=self.device, dtype=torch.float32
        )
        with torch.no_grad():
            predicted_embedding = self.mlp_model(inputs)

        return predicted_embedding